from src.utils import *

from scipy.linalg import solve as linear_solver
from scipy.optimize import minimize_scalar


class RidgeLimit(object):
    """Simulate the theoretical results for the limiting risk of ridge regression. Note that we specify here a model
    with finite dimensions (d, l), based on which we compute the limiting risk for d, l -> infinity."""

    def __init__(self, alpha=None, beta=None, sigma_sq=None, M=None, stat_model_params=None):
        """Set up the causal model and derive the quantities of the induced statistical model.

        Args:
            alpha (np.array, shape=l): linear influence of confounder z on y. If a tuple (l, r) = (int, float) is
                                       given, alpha is drawn with draw_spherical(l, r), i.e. from the uniform
                                       distribution on the sphere in l dimensions with radius r.
            beta (np.array, shape=d): linear influence of covariates x on y. If a tuple (d, r) = (int, float) is
                                      given, beta is drawn with draw_spherical(d, r).
            sigma_sq (float, positive): noise variance. Ignored when stat_model_params is given, because it is then
                                        derived from the statistical parameters.
            M (np.array, shape=(d, l)): linear influence of confounder z on covariates x. If None, M is generated
                                        with generate_M such that Sigma = M M^T = I.
            stat_model_params (tuple): Either None or (r_stat, sigma_sq_stat, conf_strength, d, l). If given, the
                                       causal model parameters (alpha, beta, sigma_sq) are chosen to match the
                                       specified statistical model parameters, confounding strength, and dimensions;
                                       the alpha and beta arguments are then not used.

        Raises:
            AssertionError: if M does not have shape (d, l) or if M M^T is not (numerically) the identity.
        """
        # Vectors and matrices that define the model
        if stat_model_params is None:
            # isinstance (rather than an exact type comparison) also accepts tuple subclasses such as namedtuples
            if isinstance(alpha, tuple):
                alpha = draw_spherical(*alpha)
            self.alpha = alpha
            if isinstance(beta, tuple):
                beta = draw_spherical(*beta)
            self.beta = beta
            self.d, self.l = len(self.beta), len(self.alpha)
        else:
            r_stat, sigma_sq_stat, conf_strength, d, l = stat_model_params
            # Note: this overrides the sigma_sq argument
            s, r, sigma_sq = self.causal_params_from_statistical(r_stat, sigma_sq_stat, conf_strength, d, l)
            self.alpha = draw_spherical(l, s)
            self.beta = draw_spherical(d, r)
            self.d, self.l = d, l

        # Distribution of latent variables, initialized as Gaussian
        self.sample_z = lambda n_samples: np.random.standard_normal(size=(n_samples, self.l))

        if M is None:
            M = self.generate_M((self.d, self.l))
        assert M.shape == (self.d, self.l), "Shapes of M don't match alpha and beta!"
        self.M = M
        self.Sigma = self.M @ self.M.T
        assert np.isclose(self.Sigma, np.eye(self.d)).all(), "Sigma is not identity!"
        self.Gamma = self.M @ self.alpha

        # Signal and noise strengths
        self.sigma_sq = sigma_sq
        self.r_sq = np.inner(self.beta, self.beta)
        self.norm_alpha_sq = np.inner(self.alpha, self.alpha)
        self.omega_sq = np.inner(self.Gamma, self.Gamma)
        self.eta = np.inner(self.beta, self.Gamma)

        # Quantities of corresponding statistical model
        self.beta_stat = np.linalg.pinv(self.Sigma) @ (self.Sigma @ self.beta + self.Gamma)
        self.sigma_sq_stat = self.sigma_sq + self.norm_alpha_sq - self.Gamma.T @ self.Sigma @ self.Gamma
        self.r_sq_stat = np.inner(self.beta_stat, self.beta_stat)

        self.conf_strength = self.omega_sq / (self.omega_sq + self.r_sq)
        self.conf_strength_eta = (self.omega_sq + self.eta) / self.r_sq_stat
        self.SNR_caus = (self.r_sq + self.eta) / self.sigma_sq_stat
        self.SNR_stat = self.r_sq_stat / self.sigma_sq_stat

        # Irreducible parts of the risks; only added to the limiting risks when include_bayes is True
        self.bayes_risk_caus = self.sigma_sq + self.norm_alpha_sq
        self.bayes_risk_stat = self.sigma_sq_stat
        self.include_bayes = False

        # Search interval for the numerical optimization of the ridge parameter
        self.optimizer_bounds = (0, 300)

    def sample(self, n_samples):
        """Draw n_samples observations (X, Y) from the causal (confounded) model.

        The latent variable Z is produced by self.sample_z (standard Gaussian unless it has been replaced). The
        covariates are X = Z M^T and the response is Y = X beta + Z alpha + eps, with Gaussian noise eps of
        variance sigma_sq.

        Returns:
            X with one row per sample (shape (n_samples, d) for the default sample_z) and Y as a column vector of
            shape (n_samples, 1).
        """
        latent = self.sample_z(n_samples)
        covariates = latent @ self.M.T
        noise = np.random.normal(loc=0, scale=np.sqrt(self.sigma_sq), size=n_samples)
        # Column vectors keep the response two-dimensional
        beta_col = self.beta.reshape(-1, 1)
        alpha_col = self.alpha.reshape(-1, 1)
        response = covariates @ beta_col + latent @ alpha_col + noise.reshape(-1, 1)
        return covariates, response

    def get_beta_ridge(self, X, Y, lam):
        """Return the ridge regression coefficients for the train set (X, Y) with regularization lam.

        Solves (X^T X / n + lam * I) b = X^T Y / n, where n is the number of rows of X, and returns b as a flat
        array. The system matrix is passed to the solver as positive definite.
        """
        n_train = X.shape[0]
        regularized_cov = X.T @ X / n_train + lam * np.eye(self.d)
        cross_moment = X.T @ Y / n_train
        solution = linear_solver(a=regularized_cov, b=cross_moment, assume_a='pos')
        return solution.flatten()

    def causal_params_from_statistical(self, r_stat, sigma_sq_stat, conf_strength, d, l):
        """Given a set of target parameters for the statistical model, outputs corresponding causal parameters.

        The targets are matched in expectation over an alpha drawn uniformly from the sphere of radius s in l
        dimensions and an M with orthonormal rows (Sigma = I), for which E||Gamma||^2 = d / l * s^2:
            - omega^2 = d / l * s^2 = conf_strength * r_stat^2
            - r^2 = r_stat^2 - omega^2
            - sigma_sq_stat = sigma_sq + ||alpha||^2 - ||Gamma||^2 = sigma_sq + (l - d) / l * s^2

        Args:
            r_stat (float): target norm of the statistical parameter.
            sigma_sq_stat (float): target noise variance of the statistical model.
            conf_strength (float): target confounding strength omega^2 / r_stat^2.
            d (int): dimension of the covariates.
            l (int): dimension of the confounder.

        Returns:
            Tuple (s, r, sigma_sq) with ||alpha|| = s, ||beta|| = r and causal noise variance sigma_sq.

        Raises:
            AssertionError: if the implied causal noise variance is negative.
        """
        s = np.sqrt(l / d * r_stat**2 * conf_strength)
        r = np.sqrt(r_stat**2 - d / l * s**2)
        # The part of alpha^T z that x cannot explain has variance ||alpha||^2 - ||Gamma||^2 = (l - d) / l * s^2.
        # (Dividing by d instead of l here would overshoot by a factor l / d.)
        sigma_sq = sigma_sq_stat - (l - d) / l * s**2
        assert sigma_sq >= 0, "Negative noise variance! Reduce l/d, r_stat, or conf_strength."
        return s, r, sigma_sq

    def generate_M(self, shape):
        """Draw a random matrix M of the given shape (d, l) that satisfies M @ M.T = I.

        M is assembled as U @ [I_d, 0] @ V^T from two random orthogonal matrices U (d x d) and V (l x l), so all
        of its singular values equal one. Requires d <= l.
        """
        d = shape[0]
        l = shape[1]
        assert d <= l, "Invalid dimensions: l needs to be larger than d for generating isotropic Sigma!"

        left_basis = ortho_group.rvs(d)
        right_basis = ortho_group.rvs(l)
        # d x l matrix with ones on the main diagonal and zeros elsewhere
        unit_singular_values = np.eye(d, l)
        return (left_basis @ unit_singular_values) @ right_basis.T

    def mp(self, z, gam):
        """Evaluate m(z), the Stieltjes transform of the MP-law with overparameterization ratio gam.

        Computes (1 - gam - z - sqrt((1 - gam - z)^2 - 4 * gam * z)) / (2 * gam * z).
        """
        shifted = 1 - gam - z
        discriminant = shifted**2 - 4 * gam * z
        return (shifted - np.sqrt(discriminant)) / (2 * gam * z)

    def mp_deriv(self, z, gam):
        """Evaluate m'(z), the derivative of the Stieltjes transform of the MP-law with overparameterization
        ratio gam."""
        scale = 2 * gam * z**2
        # (gam - z - 1)^2 - 4 * z equals (1 - gam - z)^2 - 4 * gam * z, the discriminant appearing in m(z)
        root = np.sqrt((gam - z - 1)**2 - 4 * z)
        numerator = (gam - 1)**2 - (gam + 1) * z
        return (gam - 1) / scale + numerator / (scale * root)

    def M_0(self, z, gam):
        """Evaluate z * m(-z), where m is the Stieltjes transform of the MP-law with overparameterization ratio
        gam."""
        two_gam = 2 * gam
        root = np.sqrt((1 - gam + z)**2 + 4 * gam * z)
        return (gam - 1) / two_gam - z / two_gam + root / two_gam

    def M_1(self, z, gam):
        """Evaluate z**2 * m'(-z), where m is the Stieltjes transform of the MP-law with overparameterization
        ratio gam."""
        two_gam = 2 * gam
        root = np.sqrt((gam + z - 1)**2 + 4 * z)
        numerator = (gam - 1)**2 + (gam + 1) * z
        return (gam - 1) / two_gam + numerator / (two_gam * root)

    def M_2(self, z, gam):
        """Evaluate gam * (m(-z) - z * m'(-z)), where m is the Stieltjes transform of the MP-law with
        overparameterization ratio gam.

        Equivalently gam * (M_0(z, gam) - M_1(z, gam)) / z, computed in the closed form
        -1/2 + (1 + gam + z) / (2 * sqrt((1 - gam + z)^2 + 4 * gam * z)).
        """
        root = np.sqrt((1 - gam + z)**2 + 4 * gam * z)
        return - 1 / 2 + (1 + gam + z) / (2 * root)

    def bias_caus(self, lam, gam):
        """Evaluate the limiting causal bias of the ridge regression solution with regularization lam and
        overparameterization ratio gam.

        For lam == 0 the ridgeless closed form is used; otherwise the bias is
        omega_sq + r_sq_stat * M_1(lam, gam) - 2 * (eta + omega_sq) * M_0(lam, gam).
        """
        # The formulas divide by gam, so an exact zero is replaced by a tiny positive ratio
        gam = 1e-7 if gam == 0 else gam
        if lam != 0:
            cross_weight = 2 * (self.eta + self.omega_sq)
            return self.omega_sq + self.r_sq_stat * self.M_1(lam, gam) - cross_weight * self.M_0(lam, gam)
        # Ridgeless case: the factor below is 0 for gam <= 1 and (gam - 1) / gam for gam > 1
        numerator = gam - 1 + np.abs(gam - 1)
        return self.omega_sq + (self.r_sq - self.omega_sq) * numerator / (2 * gam)

    def var(self, lam, gam):
        """Evaluate the limiting variance (shared by the causal and the statistical risk) of the ridge regression
        solution with regularization lam and overparameterization ratio gam.

        For lam == 0 the ridgeless closed form is used; otherwise the variance is sigma_sq_stat * M_2(lam, gam).
        """
        # An exact zero ratio is replaced by a tiny positive one, as in the bias functions
        gam = 1e-7 if gam == 0 else gam
        if lam != 0:
            return self.sigma_sq_stat * self.M_2(lam, gam)
        # Ridgeless case: gam / (1 - gam) for gam < 1 and 1 / (gam - 1) for gam > 1
        distance_to_one = np.abs(1 - gam)
        return self.sigma_sq_stat * (1 + gam - distance_to_one) / (2 * distance_to_one)

    def risk_caus(self, lam, gam):
        """Evaluate the limiting causal risk (causal bias plus variance) of the ridge regression solution with
        regularization lam and overparameterization ratio gam. The causal Bayes risk is added only when
        self.include_bayes is set."""
        total = self.bias_caus(lam, gam) + self.var(lam, gam)
        if not self.include_bayes:
            return total
        return total + self.bayes_risk_caus

    def bias_stat(self, lam, gam):
        """Evaluate the limiting statistical bias of the ridge regression solution with regularization lam and
        overparameterization ratio gam.

        For lam == 0 the ridgeless closed form is used; otherwise the bias is r_sq_stat * M_1(lam, gam).
        """
        # The formulas divide by gam, so an exact zero is replaced by a tiny positive ratio
        gam = 1e-7 if gam == 0 else gam
        if lam != 0:
            return self.r_sq_stat * self.M_1(lam, gam)
        # Ridgeless case: the factor below is 0 for gam <= 1 and (gam - 1) / gam for gam > 1
        numerator = gam - 1 + np.abs(gam - 1)
        return self.r_sq_stat * numerator / (2 * gam)

    def risk_stat(self, lam, gam):
        """Evaluate the limiting statistical risk (statistical bias plus variance) of the ridge regression solution
        with regularization lam and overparameterization ratio gam. The statistical Bayes risk is added only when
        self.include_bayes is set."""
        total = self.bias_stat(lam, gam) + self.var(lam, gam)
        if not self.include_bayes:
            return total
        return total + self.bayes_risk_stat

    def oracle_lam_caus(self, gam, optimize=True):
        """Return the causally optimal ridge parameter and the corresponding limiting causal risk for a given
        overparameterization ratio gam.

        Args:
            gam: overparameterization ratio.
            optimize (bool): if True, the causal risk is minimized numerically over self.optimizer_bounds;
                             otherwise the closed-form large-gamma approximation lam = A * gam + B is used.

        Returns:
            Tuple (lam_optim, risk_optim), where risk_optim = risk_caus(lam_optim, gam).
        """
        if not optimize:
            r_sq, omega_sq = self.r_sq, self.omega_sq
            eta, sigma_sq_stat = self.eta, self.sigma_sq_stat
            # Slope and intercept of the asymptotic (large gam) expression for the optimal lambda
            slope = (sigma_sq_stat - eta + omega_sq) / (r_sq - eta)
            intercept = -(eta - omega_sq) * (sigma_sq_stat + eta - 2*r_sq + omega_sq) / ((r_sq - eta)*(r_sq+omega_sq+sigma_sq_stat-2*eta))
            lam_optim = slope * gam + intercept
        else:
            def causal_risk(lam):
                return self.risk_caus(lam=lam, gam=gam)

            lam_optim = minimize_scalar(causal_risk, method='bounded', bounds=self.optimizer_bounds).x

        return lam_optim, self.risk_caus(lam_optim, gam)

    def oracle_lam_stat(self, gam):
        """Return the statistically optimal ridge parameter lam = gam * sigma_sq_stat / r_sq_stat together with
        the corresponding limiting statistical risk for a given overparameterization ratio gam."""
        best_lam = gam * self.sigma_sq_stat / self.r_sq_stat
        best_risk = self.risk_stat(best_lam, gam)
        return best_lam, best_risk

    def get_loss_test(self, beta_hat, X_test, Y_test):
        """Return the mean squared prediction error of beta_hat on the test set (X_test, Y_test). Y_test is
        flattened before it is compared with the predictions X_test @ beta_hat."""
        predictions = X_test @ beta_hat
        residuals = predictions - Y_test.flatten()
        return np.mean(residuals ** 2)

    def get_loss_caus(self, beta_hat):
        """Return the causal loss of the predictor beta_hat (without Bayes risk): the deviation from the causal
        parameter beta, passed to weighted_dot with weight matrix Sigma."""
        deviation = beta_hat - self.beta
        return weighted_dot(deviation, A=self.Sigma)

    def get_loss_stat(self, beta_hat):
        """Return the statistical loss of the predictor beta_hat (without Bayes risk): the deviation from the
        statistical parameter beta_stat, passed to weighted_dot with weight matrix Sigma."""
        deviation = beta_hat - self.beta_stat
        return weighted_dot(deviation, A=self.Sigma)

    def get_finite_risks(self, lam, gam):
        """Draw a random train set with n = int(d / gam) samples, fit ridge regression with regularization lam and
        return the pair (causal loss, statistical loss) of the fitted coefficients."""
        n_train = int(self.d / gam)
        # The ridgeless case gets a tiny regularization for numeric stability
        lam_used = 1e-7 if lam == 0 else lam

        # Training data are drawn via self.sample
        X_train, Y_train = self.sample(n_samples=n_train)
        beta_fit = self.get_beta_ridge(X_train, Y_train, lam_used)
        return self.get_loss_caus(beta_fit), self.get_loss_stat(beta_fit)

    def compute_risks(self, lam, gams, include_finite=False):
        """Compute the limiting biases, variance and risks for a sequence gams of overparameterization ratios.

        Args:
            lam: regularization parameter. If it is 'optim_caus' or 'optim_stat', it is chosen optimally with
                 respect to the causal or statistical risk separately for each gamma.
            gams: iterable of overparameterization ratios.
            include_finite (bool): if True, the finite-sample losses from get_finite_risks are recorded as well;
                                   otherwise the two 'finite_loss_*' lists stay empty.

        Returns:
            Dictionary that maps each quantity name to a list with one entry per element of gams.
        """
        quantity_names = ('bias_caus',
                          'bias_stat',
                          'var',
                          'risk_caus',
                          'risk_stat',
                          'finite_loss_caus',
                          'finite_loss_stat')
        res_dict = {name: [] for name in quantity_names}

        for gam in gams:
            # Resolve the regularization used for this ratio
            if lam == 'optim_caus':
                lam_gam = self.oracle_lam_caus(gam=gam, optimize=True)[0]
            elif lam == 'optim_stat':
                lam_gam = self.oracle_lam_stat(gam=gam)[0]
            else:
                lam_gam = lam

            limiting = (('bias_caus', self.bias_caus),
                        ('bias_stat', self.bias_stat),
                        ('var', self.var),
                        ('risk_caus', self.risk_caus),
                        ('risk_stat', self.risk_stat))
            for name, quantity_fn in limiting:
                res_dict[name].append(quantity_fn(lam_gam, gam))

            if not include_finite:
                continue
            finite_caus, finite_stat = self.get_finite_risks(lam_gam, gam)
            res_dict['finite_loss_caus'].append(finite_caus)
            res_dict['finite_loss_stat'].append(finite_stat)

        return res_dict

    def compute_oracle(self, gams, include_finite=False):
        """Compute oracle regularization parameters and the corresponding risks for a sequence gams of
        overparameterization ratios.

        For every gam the causal oracle is computed twice (numerically and via the asymptotic formula) and the
        statistical oracle once. If include_finite is True, the finite-sample causal loss at the numerically
        optimal lambda is recorded too; otherwise 'finite_loss_optim_caus' stays empty.

        Returns:
            Dictionary that maps each quantity name to a list with one entry per element of gams.
        """
        quantity_names = ('lam_optim_caus',
                          'lam_optim_caus_asymp',
                          'lam_optim_stat',
                          'bias_optim_caus',
                          'variance_optim_caus',
                          'finite_loss_optim_caus',
                          'risk_optim_caus',
                          'risk_optim_caus_asymp',
                          'risk_optim_stat')
        oracle_dict = {name: [] for name in quantity_names}

        for gam in gams:
            lam_num, risk_num = self.oracle_lam_caus(gam, optimize=True)
            bias_num = self.bias_caus(lam=lam_num, gam=gam)
            var_num = self.var(lam=lam_num, gam=gam)
            lam_asymp, risk_asymp = self.oracle_lam_caus(gam, optimize=False)
            lam_stat, risk_stat = self.oracle_lam_stat(gam)

            current = {'lam_optim_caus': lam_num,
                       'lam_optim_caus_asymp': lam_asymp,
                       'lam_optim_stat': lam_stat,
                       'bias_optim_caus': bias_num,
                       'variance_optim_caus': var_num,
                       'risk_optim_caus': risk_num,
                       'risk_optim_caus_asymp': risk_asymp,
                       'risk_optim_stat': risk_stat}
            for name, value in current.items():
                oracle_dict[name].append(value)

            if include_finite:
                finite_caus = self.get_finite_risks(lam_num, gam)[0]
                oracle_dict['finite_loss_optim_caus'].append(finite_caus)
        return oracle_dict

    def run_cv(self, gams, lams, num_splits='loo', num_runs=3):
        """For each gam, select the ridge parameter by cross validation on freshly sampled data and evaluate the
        limiting causal risk at the selected value.

        For every gam, num_runs independent data sets with int(d / gam) samples are drawn. On each of them the
        helper cv picks a value from lams (using get_loss_test as loss and num_splits splits), and the picks are
        averaged over the runs.

        Args:
            gams: sequence of overparameterization ratios.
            lams: candidate regularization parameters handed to cv.
            num_splits: number of cross-validation splits handed to cv; 'loo' presumably selects leave-one-out
                        (handled inside cv, not visible here).
            num_runs (int): number of random data sets, based on which lam_cv is computed. The output lam_cv is
                            the average over the individual runs.

        Returns:
            Tuple (list_lam_cv, list_risk_cv): per gam, the averaged cv-selected lambda and risk_caus at that
            lambda.
        """
        list_lam_cv = []
        list_risk_cv = []
        for gam in gams:
            # Sample size that corresponds to this overparameterization ratio
            n_val = int(self.d / gam)
            picks = []
            for _ in range(num_runs):
                X, Y = self.sample(n_samples=n_val)
                picks.append(cv(X, Y, lams, loss_fn=self.get_loss_test, num_splits=num_splits))
            # Collapse the per-run choices to their mean and evaluate the corresponding limiting risk
            mean_lam = np.mean(np.array(picks))
            list_lam_cv.append(mean_lam)
            list_risk_cv.append(self.risk_caus(mean_lam, gam))
        return list_lam_cv, list_risk_cv